TP SP examples improvement #1354

Open: wants to merge 3 commits into main

Conversation

githubsgi

Changing cuda to accelerator and adding CommDebugMode to tensor_parallel_example.py, sequence_parallel_example.py, and log_utils.py.
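Roughly, the change amounts to the sketch below (assumed wiring, not the exact diff in this PR): device selection goes through torch.accelerator instead of a hard-coded "cuda", and one training iteration is wrapped in CommDebugMode so the collectives it triggers can be inspected. The ToyModel, parallelization plan, and optimizer settings mirror the existing tensor_parallel_example.py; the exact file layout is an assumption. Run with e.g. torchrun --nproc_per_node=4.

import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))

# Device-agnostic device selection: resolves to "cuda", "xpu", "hpu", ... at runtime
# instead of hard-coding "cuda".
device_type = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type, (world_size,))

tp_model = parallelize_module(
    ToyModel().to(device_type),
    device_mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=0.25)

# Wrap one iteration in CommDebugMode to record the collectives
# (all_reduce for TP; all_gather/reduce_scatter for SP) issued by the sharded model.
comm_mode = CommDebugMode()
with comm_mode:
    inp = torch.rand(4, 10, device=device_type)
    tp_model(inp).sum().backward()
    optimizer.step()

print(comm_mode.get_comm_counts())
print(comm_mode.get_sharding_info())
print(comm_mode.generate_comm_debug_tracing_table(noise_level=2))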


netlify bot commented Jun 11, 2025

Deploy Preview for pytorch-examples-preview canceled.

🔨 Latest commit: d16c819
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-examples-preview/deploys/685b4a62e471560008f3f150

@githubsgi changed the title from "TP SP example improvement" to "TP SP examples improvement" on Jun 11, 2025
output.sum().backward()
optimizer.step()
inp = torch.rand(4, 10, device=device_type)
comm_mode = CommDebugMode()
Member

Does this work on non-CUDA devices? It would be great to share some local logs from your tests.

Author

Gladly. Please see attached logs for H100.

Starting PyTorch TP example on rank 3.
Starting PyTorch TP example on rank 0.
06/16/2025 05:55:00 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch TP example on rank 2.
Starting PyTorch TP example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:55:03 PM  Tensor Parallel training starting...
06/16/2025 05:55:03 PM  Tensor Parallel iter 0 completed
 rank3 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 1 completed
 rank0 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank2 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 1 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_reduce')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_reduce: 1
  BACKWARD PASS
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32]), torch.Size([32])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([32, 10]), torch.Size([32, 10])]
              sharding: [(Shard(dim=0),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5]), torch.Size([5])]
              sharding: [(Replicate(),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.add_.Tensor
              shape: [torch.Size([5, 32]), torch.Size([5, 32])]
              sharding: [(Shard(dim=1),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:55:03 PM  Tensor Parallel iter 2 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 3 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 4 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 5 completed
06/16/2025 05:55:03 PM  Tensor Parallel iter 6 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 7 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 8 completed
06/16/2025 05:55:04 PM  Tensor Parallel iter 9 completed
06/16/2025 05:55:04 PM  Tensor Parallel training completed!
[rank0]:[W616 17:55:04.791527408 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Starting PyTorch Sequence Parallel example on rank 0.
06/16/2025 05:53:21 PM  Device Mesh created: device_mesh=DeviceMesh('cuda', [0, 1, 2, 3])
Starting PyTorch Sequence Parallel example on rank 3.
Starting PyTorch Sequence Parallel example on rank 2.
Starting PyTorch Sequence Parallel example on rank 1.
model ToyModel(
  (in_proj): Linear(in_features=10, out_features=32, bias=True)
  (relu): ReLU()
  (out_proj): Linear(in_features=32, out_features=5, bias=True)
)
06/16/2025 05:53:24 PM  Sequence Parallel training starting...
 rank2 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 0 completed
 rank0 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank1 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
 rank3 0 get_comm_counts defaultdict(<class 'int'>, {<OpOverloadPacket(op='c10d_functional.all_gather_into_tensor')>: 2, <OpOverloadPacket(op='c10d_functional.reduce_scatter_tensor')>: 1}) get_sharding_info() {'ToyModel.in_proj.weight': (Shard(dim=0),), 'ToyModel.in_proj.bias': (Shard(dim=0),), 'ToyModel.out_proj.weight': (Shard(dim=1),), 'ToyModel.out_proj.bias': (Replicate(),)} generate_comm_debug_tracing_table Global
  FORWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    *c10d_functional.reduce_scatter_tensor: 1
  BACKWARD PASS
    *c10d_functional.all_gather_into_tensor: 1
    ToyModel
    *module type: class '__main__.ToyModel'
      FORWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        *c10d_functional.reduce_scatter_tensor: 1
      BACKWARD PASS
        *c10d_functional.all_gather_into_tensor: 1
        ToyModel.in_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=0),)
         *bias: (Shard(dim=0),)
          FORWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([32]), torch.Size([4, 10]), torch.Size([10, 32])]
              sharding: [(Shard(dim=0),), (Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32, 10])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([32])]
              sharding: [(Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.zeros_like.default
              shape: [torch.Size([5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            **aten.mm.default
              shape: [torch.Size([32, 4]), torch.Size([4, 10])]
              sharding: [(Shard(dim=0),), (Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 32])]
              sharding: [(Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
        ToyModel.relu
        *module type: class 'torch.nn.modules.activation.ReLU'
          FORWARD PASS
          BACKWARD PASS
        ToyModel.out_proj
        *module type: class 'torch.nn.modules.linear.Linear'
        *Parameter List
         *weight: (Shard(dim=1),)
         *bias: (Replicate(),)
          FORWARD PASS
            *c10d_functional.reduce_scatter_tensor: 1
            **aten.addmm.default
              shape: [torch.Size([5]), torch.Size([4, 32]), torch.Size([32, 5])]
              sharding: [(Replicate(),), (Shard(dim=1),), (Shard(dim=0),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
          BACKWARD PASS
            *c10d_functional.all_gather_into_tensor: 1
            **aten.mm.default
              shape: [torch.Size([4, 5]), torch.Size([5, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.mm.default
              shape: [torch.Size([5, 4]), torch.Size([4, 32])]
              sharding: [(Replicate(),), (Shard(dim=1),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
            **aten.sum.dim_IntList
              shape: [torch.Size([4, 5])]
              sharding: [(Replicate(),)]
              device mesh: DeviceMesh('cuda', [0, 1, 2, 3])
 
06/16/2025 05:53:25 PM  Sequence Parallel iter 1 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 2 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 3 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 4 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 5 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 6 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 7 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 8 completed
06/16/2025 05:53:25 PM  Sequence Parallel iter 9 completed
06/16/2025 05:53:25 PM  Sequence Parallel training completed!
[rank0]:[W616 17:53:25.948217933 ProcessGroupNCCL.cpp:1516] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Member

Sorry, I meant non-CUDA devices: does this API work if you use MPS or CPU?

Author

torch.accelerator works for CUDA as well as non-CUDA GPUs and accelerators. CommDebugMode is also a PyTorch feature, so it should work on all devices; if not, that would be a bug.
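For illustration only (not the exact code in the PR), the device-agnostic selection boils down to something like the following; the same lines work whether the build exposes CUDA, XPU, or another accelerator, and fall back to CPU when none is present.

import torch

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()  # e.g. device(type='cuda')
else:
    device = torch.device("cpu")

x = torch.rand(4, 10, device=device)
print(device, x.device)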

Author

@msaroufim, if there are no further questions, could this be merged?

Member

Can you please attach logs confirming this works on CPU?

Author

githubsgi commented Jun 23, 2025

@msaroufim, I have only used GPUs for this kind of work. The accelerator API does not support CPUs either, and I do not know whether TP and SP are supported on CPUs or, if so, which distributed backend would be used. As far as I can tell, the original code would not run on CPUs either. In short, these two examples were not written for CPUs, and adding CPU support would be a very significant change, if it is possible at all.
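For context only, a purely hypothetical sketch of the backend-selection piece that CPU support would need (gloo on CPU, nccl on CUDA); this is not part of the PR, and it is only one of several changes that would be required.

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

if torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
    backend = "nccl" if device_type == "cuda" else None  # let PyTorch pick for other accelerators
else:
    device_type = "cpu"
    backend = "gloo"  # CPU collectives go through gloo

dist.init_process_group(backend=backend)
device_mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))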

Member

I guess I'm confused by the goal of this PR overall.

  1. Why merge a device-agnostic API if the code is only expected to work on a single device? If that's the case, then keeping cuda is actually clearer.
  2. I'm not sure why CommDebugMode is introduced and why it should be the default behavior.

Author

@msaroufim, great questions. Let me address them.

  1. There are non-CUDA GPUs and accelerators (e.g. XPU, MTIA, HPU). torch.accelerator is a write-once, run-anywhere interface, so the model code runs on any of the supported accelerators without requiring surgery.

  2. Since these are distributed example codes, a way to see what is happening in the distributed layer should be very informative. It could also be gated behind an input option if that is preferable; see the sketch after this list.
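A minimal sketch of that input-option idea (hypothetical, not part of the current diff; tp_model, inp, and optimizer stand for the model, input batch, and optimizer already defined in the examples):

import argparse
import contextlib
from torch.distributed.tensor.debug import CommDebugMode

def train_step(tp_model, inp, optimizer, comm_debug=False):
    """One forward/backward/step, optionally traced with CommDebugMode."""
    comm_mode = CommDebugMode() if comm_debug else contextlib.nullcontext()
    with comm_mode:
        tp_model(inp).sum().backward()
        optimizer.step()
    if comm_debug:
        print(comm_mode.generate_comm_debug_tracing_table(noise_level=1))

parser = argparse.ArgumentParser()
parser.add_argument("--comm-debug", action="store_true",
                    help="trace collectives with CommDebugMode")
args = parser.parse_args()
# ... inside the example's training loop:
# train_step(tp_model, inp, optimizer, comm_debug=args.comm_debug)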

Member

OK, this makes sense. We can merge this if you fix the breakage in the CI job and make CommDebugMode optional in another PR.

Author

githubsgi commented Jun 25, 2025

It looks like the failing CUDA test below (Run Distributed Examples / test (pull_request)) runs with a relatively old version of PyTorch (torch==2.4.0.dev20240605+cu11), while the upcoming release is 2.8.
